{
  "cells": [
    {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
      "source": [
        "\u003ca href=\"https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/empirical_ntk_fcn.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nTt0UNQbk_Td"
      },
      "source": [
        "# Example of computing finite width NTK of an FCN on CIFAR-10 inputs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dvXMUSdFjCqq"
      },
      "source": [
        "Tested on NVIDIA T4."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Sl2JyRE1hK-z"
      },
      "source": [
        "# Imports and setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 287,
          "status": "ok",
          "timestamp": 1655534554432,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "HunkuGjSr63O",
        "outputId": "e56b6a2d-161e-46e4-a3ee-9c0ce53e8234"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "GPU 0: Tesla T4 (UUID: GPU-70c5645d-1ae8-d42c-8b35-ca8a58b81237)\n"
          ]
        }
      ],
      "source": [
        "!nvidia-smi -L"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bmmbCtfh76RS"
      },
      "outputs": [],
      "source": [
        "# We need at least jaxlib-0.1.73 to avoid certain CUDA bugs when using `implementation=auto`\n",
        "!pip install -q --upgrade pip\n",
        "!pip install -q jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
        "!pip install -q git+https://www.github.com/google/neural-tangents"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qM24rCkDI3F8"
      },
      "outputs": [],
      "source": [
        "from jax import jit\n",
        "from jax import numpy as jnp\n",
        "from jax import random\n",
        "\n",
        "import neural_tangents as nt\n",
        "from neural_tangents import stax"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tt0rL4ycp4ZB"
      },
      "source": [
        "# Defining a simple FCN model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lPh5LGz9JBK_"
      },
      "outputs": [],
      "source": [
        "def get_ntk_fns(O: int):\n",
        "  # Define an FCN.\n",
        "  init_fn, apply_fn, _ = stax.serial(\n",
        "      stax.Dense(2048),\n",
        "      stax.Relu(),\n",
        "      stax.Dense(2048),\n",
        "      stax.Relu(),\n",
        "      stax.Dense(2048),\n",
        "      stax.Relu(),\n",
        "      stax.Dense(O)\n",
        "  )\n",
        "\n",
        "  kwargs = dict(\n",
        "      f=apply_fn,\n",
        "      trace_axes=(),\n",
        "      vmap_axes=0\n",
        "  )\n",
        "\n",
        "  # Different NTK implementations\n",
        "  jacobian_contraction = jit(nt.empirical_ntk_fn(\n",
        "      **kwargs, implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION))\n",
        "  ntvp = jit(nt.empirical_ntk_fn(\n",
        "      **kwargs, implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS))\n",
        "  str_derivatives = jit(nt.empirical_ntk_fn(\n",
        "      **kwargs, implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES))\n",
        "  auto = jit(nt.empirical_ntk_fn(\n",
        "      **kwargs, implementation=nt.NtkImplementation.AUTO))\n",
        "\n",
        "  # Parameters \\theta\n",
        "  _, params = init_fn(random.PRNGKey(0), x1.shape)\n",
        "  return params, (jacobian_contraction, ntvp, str_derivatives, auto)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0lWFC3QEgao4"
      },
      "source": [
        "# Benchmark"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kbExWKkUg9ew"
      },
      "source": [
        "Structured derivatives compute NTK fastest. NTK-vector products also provide a speedup, due to a cheap forward pass relative to parameters size."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bwhIZWqxKTlt"
      },
      "outputs": [],
      "source": [
        "O = 8\n",
        "N = 16\n",
        "\n",
        "# Input images x\n",
        "input_shape = (3072,)\n",
        "k1, k2 = random.split(random.PRNGKey(1), 2)\n",
        "x1 = random.normal(k1, (N, *input_shape))\n",
        "x2 = random.normal(k2, (N, *input_shape))\n",
        "\n",
        "params, (ntk_fn_jacobian_contraction, ntk_fn_ntvp, ntk_fn_str_derivatives, ntk_fn_auto) = get_ntk_fns(O=O)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 398,
          "status": "ok",
          "timestamp": 1655534571392,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "zObT8WnPggFo",
        "outputId": "ef32ed5e-f51c-4f32-feb1-5598fc6bc9af"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "(16, 16, 8, 8)\n"
          ]
        }
      ],
      "source": [
        "# Jacobian contraction\n",
        "k_1 = ntk_fn_jacobian_contraction(x1, x2, params)\n",
        "print(k_1.shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 536,
          "status": "ok",
          "timestamp": 1655534571924,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "FW9gJJ4qggFp",
        "outputId": "b38d0f40-8cc2-4e54-ce1b-5286648d069b"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "(16, 16, 8, 8)\n"
          ]
        }
      ],
      "source": [
        "# NTK-vector products\n",
        "k_2 = ntk_fn_ntvp(x1, x2, params)\n",
        "print(k_2.shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 644,
          "status": "ok",
          "timestamp": 1655534572565,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "gFeWnqGQggFp",
        "outputId": "c2c1dfd1-2355-451f-9a82-8c1d5b2662e2"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "(16, 16, 8, 8)\n"
          ]
        }
      ],
      "source": [
        "# Structured derivatives\n",
        "k_3 = ntk_fn_str_derivatives(x1, x2, params)\n",
        "print(k_3.shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 4,
          "status": "ok",
          "timestamp": 1655534572565,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "Q63v1L1aggFp",
        "outputId": "ebaa3170-10ed-4098-c056-24e222845b68"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "8.350792e-06 8.350792e-06 3.7114617e-06\n"
          ]
        }
      ],
      "source": [
        "# Make sure kernels agree.\n",
        "print(\n",
        "    jnp.max(jnp.abs(k_1 - k_2)) / jnp.mean(jnp.abs(k_1)), \n",
        "    jnp.max(jnp.abs(k_1 - k_3)) / jnp.mean(jnp.abs(k_1)),\n",
        "    jnp.max(jnp.abs(k_2 - k_3)) / jnp.mean(jnp.abs(k_2))\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 2326,
          "status": "ok",
          "timestamp": 1655534574888,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "Ux7AEZ9fggFp",
        "outputId": "08504a47-ecd7-4eec-e5d4-8ce604d3ec02"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "impl=1, flops=3765027328.0\n",
            "impl=2, flops=1916764544.0\n",
            "impl=3, flops=2670960.0\n",
            "(16, 16, 8, 8)\n"
          ]
        }
      ],
      "source": [
        "# test {\"skip\": true}\n",
        "# Selects best method based on FLOPs at first call / compilation.\n",
        "# Takes about 3x more time to compile.\n",
        "# WARNING: due to an XLA issue, currently only works correctly on TPUs!\n",
        "# Wrong FLOPs for CPU/GPU of JITted functions.\n",
        "k_0 = ntk_fn_auto(x1, x2, params)\n",
        "print(k_0.shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 1631,
          "status": "ok",
          "timestamp": 1655534576513,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "diP7nkBuggFp",
        "outputId": "014eb040-f8c3-417e-b1ec-9197c26655af"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1 loop, best of 5: 243 ms per loop\n"
          ]
        }
      ],
      "source": [
        "# test {\"skip\": true}\n",
        "%%timeit\n",
        "ntk_fn_jacobian_contraction(x1, x2, params).block_until_ready()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 5057,
          "status": "ok",
          "timestamp": 1655534581568,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "wehCdvi2ggFp",
        "outputId": "e7c90231-655e-4d56-86e7-87b4020ef741"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "10 loops, best of 5: 81.4 ms per loop\n"
          ]
        }
      ],
      "source": [
        "# test {\"skip\": true}\n",
        "%%timeit\n",
        "# 3X faster.\n",
        "ntk_fn_ntvp(x1, x2, params).block_until_ready()  "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 2111,
          "status": "ok",
          "timestamp": 1655534583662,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "Yrm53akVggFp",
        "outputId": "158d44a0-c68a-4545-fa91-ff6b1c8432f4"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "100 loops, best of 5: 3.46 ms per loop\n"
          ]
        }
      ],
      "source": [
        "# test {\"skip\": true}\n",
        "%%timeit\n",
        "# 70X faster.\n",
        "ntk_fn_str_derivatives(x1, x2, params).block_until_ready()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 2159,
          "status": "ok",
          "timestamp": 1655534585818,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "P1QtkBqLggFp",
        "outputId": "cd7a3361-c548-47a4-f4b2-44aff2ba492c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "100 loops, best of 5: 3.44 ms per loop\n"
          ]
        }
      ],
      "source": [
        "# test {\"skip\": true}\n",
        "%%timeit \n",
        "# On TPU should match the fastest method.\n",
        "# On GPU/CPU, currently is broken, and may not be the fastest.\n",
        "ntk_fn_auto(x1, x2, params).block_until_ready()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "Sl2JyRE1hK-z"
      ],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "empirical_ntk_fcn.ipynb",
      "provenance": [
        {
          "file_id": "19BhT0vMzkX7q07p_zYZQHgVXKigfWNO8",
          "timestamp": 1655442076432
        }
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
